Variational Autoencoder (VAE) for CIFAR-10 Color Image Synthesis¶

Foundational Models and Generative AI - Assignment 1¶

Objective: Build a VAE that can generate realistic color images from scratch, explore latent space interpolation, and experiment with β-VAE modifications.


I have implemented below components¶

  1. Encoder - Compresses images into latent distributions (μ, σ)
  2. Decoder - Reconstructs images from latent codes
  3. Reparameterization Trick - Enables backprop through sampling
  4. Training Pipeline - With loss visualization
  5. Image Generation - Dream up new images from random noise
  6. Latent Interpolation - Morph between two images smoothly
  7. β-VAE Experiment - Compare β=1 vs β=5

This image explains what exactly I have done¶

I have created streamlit app for VAE simulation¶

  • https://let-see-how-vae-works.streamlit.app/

image.png

1.Import Libraries¶

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

import torchvision
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

if device.type == 'cuda':
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("   No GPU detected! Training will be slower...")
Using device: cuda
   GPU: Tesla T4
   Memory: 15.6 GB

2. Load CIFAR-10 Dataset¶

CIFAR-10 has 60,000 32x32 color images in 10 classes. We'll normalize to [-1, 1] range since we're using Tanh in decoder output.

In [2]:
BATCH_SIZE = 128
LATENT_DIM = 128
LEARNING_RATE = 1e-3
EPOCHS = 50

IMG_CHANNELS = 3
IMG_SIZE = 32

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_dataset = torchvision.datasets.CIFAR10(
    root='./data', 
    train=True,
    download=True, 
    transform=transform
)

test_dataset = torchvision.datasets.CIFAR10(
    root='./data', 
    train=False,
    download=True, 
    transform=transform
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

print(f"Training samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")
print(f"Batches per epoch: {len(train_loader)}")
100%|██████████| 170M/170M [00:05<00:00, 31.1MB/s] 
Training samples: 50000
Test samples: 10000
Batches per epoch: 391

3. Build the Encoder Network¶

The encoder takes a 32x32x3 image and compresses it down to two vectors:

  • μ (mu): The mean of the latent distribution
  • log(σ²) (logvar): The log-variance (we use log for numerical stability)

Architecture: Conv2d layers with stride=2 to progressively downsample the image.

Input: 32x32x3
  ↓ Conv 3→32, stride 2
16x16x32
  ↓ Conv 32→64, stride 2  
8x8x64
  ↓ Conv 64→128, stride 2
4x4x128
  ↓ Conv 128→256, stride 2
2x2x256 = 1024 features
  ↓ Linear
μ (128) and logvar (128)
In [3]:
class Encoder(nn.Module):
    def __init__(self, latent_dim=128):
        super(Encoder, self).__init__()
        
        self.conv_layers = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
        )
        
        self.flatten_size = 256 * 2 * 2
        self.fc_mu = nn.Linear(self.flatten_size, latent_dim)
        self.fc_logvar = nn.Linear(self.flatten_size, latent_dim)
    
    def forward(self, x):
        x = self.conv_layers(x)
        x = x.view(x.size(0), -1)
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar

print("Encoder defined")
Encoder defined

4. Decoder Network¶

The decoder is essentially the encoder in reverse. It takes a latent vector and upsamples it back to a full image.

Architecture: Using ConvTranspose2d (transposed convolutions) to upsample.

Input: z (128)
  ↓ Linear
1024 → reshape to 2x2x256
  ↓ ConvT 256→128, stride 2
4x4x128
  ↓ ConvT 128→64, stride 2
8x8x64
  ↓ ConvT 64→32, stride 2
16x16x32
  ↓ ConvT 32→3, stride 2
32x32x3 (with Tanh for [-1,1] output)
In [4]:
class Decoder(nn.Module):
    def __init__(self, latent_dim=128):
        super(Decoder, self).__init__()
        
        self.latent_dim = latent_dim
        
        self.fc = nn.Sequential(
            nn.Linear(latent_dim, 256 * 2 * 2),
            nn.ReLU(inplace=True)
        )
        
        self.deconv_layers = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )
    
    def forward(self, z):
        x = self.fc(z)
        x = x.view(x.size(0), 256, 2, 2)
        x = self.deconv_layers(x)
        return x

print("Decoder defined")
Decoder defined

5. The Complete VAE with Reparameterization Trick¶

This is the most important part! The reparameterization trick allows us to backpropagate through the sampling operation.

The Problem:¶

We want to sample z ~ N(μ, σ²), but sampling is not differentiable!

The Solution (Reparameterization Trick):¶

Instead of sampling z directly, we:

  1. Sample ε ~ N(0, 1) (standard normal)
  2. Compute z = μ + σ × ε

This way, the randomness comes from ε (which doesn't need gradients), and μ and σ are deterministic functions that CAN be optimized!

$$z = \mu + \sigma \cdot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)$$

In [5]:
class VAE(nn.Module):
    def __init__(self, latent_dim=128):
        super(VAE, self).__init__()
        
        self.latent_dim = latent_dim
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)
    
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        epsilon = torch.randn_like(std)
        z = mu + std * epsilon
        return z
    
    def forward(self, x):
        mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar)
        x_reconstructed = self.decoder(z)
        return x_reconstructed, mu, logvar
    
    def generate(self, num_samples, device):
        z = torch.randn(num_samples, self.latent_dim).to(device)
        with torch.no_grad():
            samples = self.decoder(z)
        return samples

model = VAE(LATENT_DIM).to(device)

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"VAE model created!")
print(f"   Total parameters: {total_params:,}")
print(f"   Trainable parameters: {trainable_params:,}")
VAE model created!
   Total parameters: 1,775,939
   Trainable parameters: 1,775,939

6. Loss Function (ELBO)¶

The VAE loss has two parts:

1. Reconstruction Loss¶

Measures how well the decoder recreates the input image. We use MSE loss.

$$\mathcal{L}_{recon} = ||x - \hat{x}||^2$$

2. KL Divergence Loss¶

Regularizes the latent space to be close to a standard normal distribution N(0, I). This is what makes the latent space smooth and interpolable!

$$\mathcal{L}_{KL} = -\frac{1}{2} \sum_{j=1}^{J} (1 + \log(\sigma_j^2) - \mu_j^2 - \sigma_j^2)$$

Total Loss (ELBO)¶

$$\mathcal{L} = \mathcal{L}_{recon} + \beta \cdot \mathcal{L}_{KL}$$

For standard VAE, β = 1. For β-VAE, we'll try higher values!

In [6]:
def vae_loss(recon_x, x, mu, logvar, beta=1.0):
    batch_size = x.size(0)
    
    recon_loss = F.mse_loss(recon_x, x, reduction='sum') / batch_size
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / batch_size
    total_loss = recon_loss + beta * kl_loss
    
    return total_loss, recon_loss, kl_loss

print("Loss function defined")
Loss function defined

7. Training Loop¶

Now let's train our VAE! I'll track both losses separately so we can see what's happening.

In [7]:
def train_vae(model, train_loader, epochs, beta=1.0, lr=1e-3):
    optimizer = optim.Adam(model.parameters(), lr=lr)
    
    history = {
        'total_loss': [],
        'recon_loss': [],
        'kl_loss': []
    }
    
    model.train()
    
    for epoch in range(epochs):
        epoch_total = 0
        epoch_recon = 0
        epoch_kl = 0
        num_batches = 0
        
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")
        
        for batch_idx, (data, _) in enumerate(pbar):
            data = data.to(device)
            
            optimizer.zero_grad()
            recon_batch, mu, logvar = model(data)
            loss, recon_loss, kl_loss = vae_loss(recon_batch, data, mu, logvar, beta)
            loss.backward()
            optimizer.step()
            
            epoch_total += loss.item()
            epoch_recon += recon_loss.item()
            epoch_kl += kl_loss.item()
            num_batches += 1
            
            pbar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'recon': f'{recon_loss.item():.4f}',
                'kl': f'{kl_loss.item():.4f}'
            })
        
        avg_total = epoch_total / num_batches
        avg_recon = epoch_recon / num_batches
        avg_kl = epoch_kl / num_batches
        
        history['total_loss'].append(avg_total)
        history['recon_loss'].append(avg_recon)
        history['kl_loss'].append(avg_kl)
        
        print(f"  → Avg Loss: {avg_total:.4f} | Recon: {avg_recon:.4f} | KL: {avg_kl:.4f}")
    
    return history

print("Training function defined")
Training function defined
In [8]:
print("=" * 60)
print("TRAINING STANDARD VAE (β = 1)")
print("=" * 60)

model = VAE(LATENT_DIM).to(device)
history = train_vae(model, train_loader, EPOCHS, beta=1.0, lr=LEARNING_RATE)

print("\nTraining complete!")
============================================================
TRAINING STANDARD VAE (β = 1)
============================================================
Epoch 1/50:   0%|          | 0/391 [00:00<?, ?it/s]
Epoch 1/50: 100%|██████████| 391/391 [00:15<00:00, 25.78it/s, loss=254.9640, recon=214.2250, kl=40.7390]
  → Avg Loss: 381.9684 | Recon: 346.1202 | KL: 35.8482
Epoch 2/50: 100%|██████████| 391/391 [00:13<00:00, 28.65it/s, loss=221.4289, recon=179.6664, kl=41.7625]
  → Avg Loss: 251.6903 | Recon: 208.1166 | KL: 43.5737
Epoch 3/50: 100%|██████████| 391/391 [00:13<00:00, 28.30it/s, loss=230.7552, recon=181.5909, kl=49.1642]
  → Avg Loss: 231.4502 | Recon: 185.6165 | KL: 45.8337
Epoch 4/50: 100%|██████████| 391/391 [00:13<00:00, 27.98it/s, loss=210.9300, recon=162.3524, kl=48.5777]
  → Avg Loss: 222.4549 | Recon: 174.7670 | KL: 47.6879
Epoch 5/50: 100%|██████████| 391/391 [00:14<00:00, 27.71it/s, loss=218.5691, recon=168.8008, kl=49.7683]
  → Avg Loss: 214.3061 | Recon: 164.6694 | KL: 49.6367
Epoch 6/50: 100%|██████████| 391/391 [00:14<00:00, 27.68it/s, loss=228.2435, recon=175.8781, kl=52.3655]
  → Avg Loss: 210.3141 | Recon: 159.4566 | KL: 50.8574
Epoch 7/50: 100%|██████████| 391/391 [00:13<00:00, 28.25it/s, loss=221.1688, recon=167.6095, kl=53.5593]
  → Avg Loss: 207.1925 | Recon: 155.3664 | KL: 51.8262
Epoch 8/50: 100%|██████████| 391/391 [00:14<00:00, 27.60it/s, loss=209.1390, recon=155.9035, kl=53.2355]
  → Avg Loss: 204.6212 | Recon: 151.9302 | KL: 52.6910
Epoch 9/50: 100%|██████████| 391/391 [00:13<00:00, 29.28it/s, loss=207.8464, recon=154.7688, kl=53.0776]
  → Avg Loss: 203.1448 | Recon: 149.8820 | KL: 53.2628
Epoch 10/50: 100%|██████████| 391/391 [00:13<00:00, 28.89it/s, loss=201.8139, recon=146.1966, kl=55.6173]
  → Avg Loss: 201.9444 | Recon: 148.1459 | KL: 53.7986
Epoch 11/50: 100%|██████████| 391/391 [00:13<00:00, 29.54it/s, loss=209.9300, recon=156.3389, kl=53.5910]
  → Avg Loss: 200.9196 | Recon: 146.7985 | KL: 54.1210
Epoch 12/50: 100%|██████████| 391/391 [00:13<00:00, 28.72it/s, loss=186.0677, recon=133.0940, kl=52.9736]
  → Avg Loss: 199.6977 | Recon: 145.1711 | KL: 54.5267
Epoch 13/50: 100%|██████████| 391/391 [00:13<00:00, 29.10it/s, loss=202.0656, recon=144.7834, kl=57.2821]
  → Avg Loss: 198.6460 | Recon: 143.8008 | KL: 54.8452
Epoch 14/50: 100%|██████████| 391/391 [00:13<00:00, 29.03it/s, loss=181.9039, recon=128.6232, kl=53.2807]
  → Avg Loss: 198.0756 | Recon: 142.8415 | KL: 55.2341
Epoch 15/50: 100%|██████████| 391/391 [00:13<00:00, 29.25it/s, loss=194.2505, recon=139.8730, kl=54.3775]
  → Avg Loss: 197.0790 | Recon: 141.5088 | KL: 55.5702
Epoch 16/50: 100%|██████████| 391/391 [00:13<00:00, 29.07it/s, loss=195.4708, recon=136.5400, kl=58.9308]
  → Avg Loss: 196.4368 | Recon: 140.4905 | KL: 55.9463
Epoch 17/50: 100%|██████████| 391/391 [00:13<00:00, 28.62it/s, loss=193.3913, recon=136.3209, kl=57.0705]
  → Avg Loss: 195.3040 | Recon: 138.9997 | KL: 56.3043
Epoch 18/50: 100%|██████████| 391/391 [00:13<00:00, 29.22it/s, loss=201.9677, recon=145.5359, kl=56.4318]
  → Avg Loss: 195.2739 | Recon: 138.7295 | KL: 56.5445
Epoch 19/50: 100%|██████████| 391/391 [00:13<00:00, 30.05it/s, loss=194.4581, recon=137.0046, kl=57.4535]
  → Avg Loss: 194.5466 | Recon: 137.6493 | KL: 56.8973
Epoch 20/50: 100%|██████████| 391/391 [00:13<00:00, 29.42it/s, loss=201.7135, recon=143.4993, kl=58.2141]
  → Avg Loss: 193.7772 | Recon: 136.6114 | KL: 57.1658
Epoch 21/50: 100%|██████████| 391/391 [00:13<00:00, 29.57it/s, loss=206.0961, recon=145.7220, kl=60.3741]
  → Avg Loss: 193.2801 | Recon: 135.9017 | KL: 57.3784
Epoch 22/50: 100%|██████████| 391/391 [00:13<00:00, 28.47it/s, loss=206.2248, recon=147.3068, kl=58.9179]
  → Avg Loss: 192.7683 | Recon: 135.2900 | KL: 57.4782
Epoch 23/50: 100%|██████████| 391/391 [00:13<00:00, 28.25it/s, loss=192.3642, recon=134.6950, kl=57.6692]
  → Avg Loss: 192.6962 | Recon: 135.0826 | KL: 57.6136
Epoch 24/50: 100%|██████████| 391/391 [00:13<00:00, 29.13it/s, loss=200.5638, recon=141.0137, kl=59.5501]
  → Avg Loss: 192.2964 | Recon: 134.5822 | KL: 57.7142
Epoch 25/50: 100%|██████████| 391/391 [00:13<00:00, 29.12it/s, loss=194.8343, recon=135.3837, kl=59.4506]
  → Avg Loss: 192.0539 | Recon: 134.1910 | KL: 57.8629
Epoch 26/50: 100%|██████████| 391/391 [00:13<00:00, 28.68it/s, loss=193.1736, recon=133.9527, kl=59.2209]
  → Avg Loss: 191.5099 | Recon: 133.4848 | KL: 58.0252
Epoch 27/50: 100%|██████████| 391/391 [00:13<00:00, 28.49it/s, loss=189.1600, recon=130.5135, kl=58.6465]
  → Avg Loss: 191.2196 | Recon: 133.1057 | KL: 58.1138
Epoch 28/50: 100%|██████████| 391/391 [00:13<00:00, 29.28it/s, loss=191.4100, recon=131.5155, kl=59.8945]
  → Avg Loss: 190.4983 | Recon: 132.2469 | KL: 58.2514
Epoch 29/50: 100%|██████████| 391/391 [00:13<00:00, 29.26it/s, loss=192.9200, recon=136.6108, kl=56.3092]
  → Avg Loss: 190.4235 | Recon: 132.0168 | KL: 58.4067
Epoch 30/50: 100%|██████████| 391/391 [00:13<00:00, 28.36it/s, loss=181.7519, recon=123.3776, kl=58.3743]
  → Avg Loss: 189.9981 | Recon: 131.5563 | KL: 58.4418
Epoch 31/50: 100%|██████████| 391/391 [00:13<00:00, 28.29it/s, loss=182.3616, recon=124.3225, kl=58.0391]
  → Avg Loss: 189.5394 | Recon: 130.8917 | KL: 58.6477
Epoch 32/50: 100%|██████████| 391/391 [00:13<00:00, 28.91it/s, loss=195.2068, recon=134.5835, kl=60.6233]
  → Avg Loss: 189.6287 | Recon: 130.9626 | KL: 58.6661
Epoch 33/50: 100%|██████████| 391/391 [00:13<00:00, 28.68it/s, loss=202.5360, recon=142.3494, kl=60.1866]
  → Avg Loss: 189.2853 | Recon: 130.4664 | KL: 58.8189
Epoch 34/50: 100%|██████████| 391/391 [00:14<00:00, 27.69it/s, loss=184.1093, recon=125.6157, kl=58.4937]
  → Avg Loss: 188.9282 | Recon: 130.0612 | KL: 58.8669
Epoch 35/50: 100%|██████████| 391/391 [00:13<00:00, 28.87it/s, loss=182.9638, recon=124.8147, kl=58.1491]
  → Avg Loss: 188.6524 | Recon: 129.7022 | KL: 58.9502
Epoch 36/50: 100%|██████████| 391/391 [00:13<00:00, 30.06it/s, loss=187.4235, recon=129.7618, kl=57.6618]
  → Avg Loss: 188.6168 | Recon: 129.6613 | KL: 58.9554
Epoch 37/50: 100%|██████████| 391/391 [00:13<00:00, 29.43it/s, loss=193.3752, recon=132.6133, kl=60.7619]
  → Avg Loss: 188.7463 | Recon: 129.7105 | KL: 59.0358
Epoch 38/50: 100%|██████████| 391/391 [00:12<00:00, 30.25it/s, loss=194.3582, recon=133.7615, kl=60.5967]
  → Avg Loss: 188.2282 | Recon: 129.1423 | KL: 59.0859
Epoch 39/50: 100%|██████████| 391/391 [00:13<00:00, 29.73it/s, loss=211.2036, recon=151.3221, kl=59.8815]
  → Avg Loss: 187.7724 | Recon: 128.5591 | KL: 59.2134
Epoch 40/50: 100%|██████████| 391/391 [00:13<00:00, 29.73it/s, loss=190.7648, recon=129.6211, kl=61.1437]
  → Avg Loss: 187.3269 | Recon: 128.0698 | KL: 59.2571
Epoch 41/50: 100%|██████████| 391/391 [00:13<00:00, 29.92it/s, loss=194.8677, recon=135.3041, kl=59.5636]
  → Avg Loss: 187.4758 | Recon: 128.1544 | KL: 59.3214
Epoch 42/50: 100%|██████████| 391/391 [00:12<00:00, 30.18it/s, loss=176.8465, recon=117.8504, kl=58.9961]
  → Avg Loss: 186.9236 | Recon: 127.5445 | KL: 59.3790
Epoch 43/50: 100%|██████████| 391/391 [00:13<00:00, 30.01it/s, loss=177.5479, recon=119.0625, kl=58.4853]
  → Avg Loss: 186.9634 | Recon: 127.5418 | KL: 59.4216
Epoch 44/50: 100%|██████████| 391/391 [00:13<00:00, 29.16it/s, loss=177.7972, recon=117.1134, kl=60.6838]
  → Avg Loss: 186.7969 | Recon: 127.2810 | KL: 59.5160
Epoch 45/50: 100%|██████████| 391/391 [00:13<00:00, 29.98it/s, loss=188.3557, recon=129.2372, kl=59.1185]
  → Avg Loss: 186.5774 | Recon: 127.0470 | KL: 59.5304
Epoch 46/50: 100%|██████████| 391/391 [00:12<00:00, 30.34it/s, loss=188.3152, recon=128.0036, kl=60.3116]
  → Avg Loss: 186.4057 | Recon: 126.8156 | KL: 59.5901
Epoch 47/50: 100%|██████████| 391/391 [00:12<00:00, 30.50it/s, loss=197.4481, recon=136.8469, kl=60.6012]
  → Avg Loss: 186.2367 | Recon: 126.6230 | KL: 59.6136
Epoch 48/50: 100%|██████████| 391/391 [00:12<00:00, 30.25it/s, loss=178.5728, recon=120.2073, kl=58.3655]
  → Avg Loss: 185.9810 | Recon: 126.3210 | KL: 59.6601
Epoch 49/50: 100%|██████████| 391/391 [00:12<00:00, 30.41it/s, loss=184.1438, recon=124.8419, kl=59.3018]
  → Avg Loss: 185.7817 | Recon: 126.0739 | KL: 59.7078
Epoch 50/50: 100%|██████████| 391/391 [00:12<00:00, 30.16it/s, loss=177.4439, recon=118.6713, kl=58.7726]
  → Avg Loss: 185.4197 | Recon: 125.6657 | KL: 59.7539

Training complete!

8. Training Curves Visualization¶

Let's plot how the model improved over time. A good VAE should show decreasing reconstruction loss while KL divergence stabilizes.

In [9]:
def plot_training_curves(history, title="Training Progress"):
    """Plot the training loss curves"""
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    
    epochs = range(1, len(history['total_loss']) + 1)
    
    # Total Loss
    axes[0].plot(epochs, history['total_loss'], 'b-', linewidth=2, label='Total Loss')
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Loss')
    axes[0].set_title('Total Loss (ELBO)')
    axes[0].grid(True, alpha=0.3)
    axes[0].legend()
    
    # Reconstruction Loss
    axes[1].plot(epochs, history['recon_loss'], 'g-', linewidth=2, label='Recon Loss')
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Loss')
    axes[1].set_title('Reconstruction Loss (MSE)')
    axes[1].grid(True, alpha=0.3)
    axes[1].legend()
    
    # KL Divergence
    axes[2].plot(epochs, history['kl_loss'], 'r-', linewidth=2, label='KL Loss')
    axes[2].set_xlabel('Epoch')
    axes[2].set_ylabel('Loss')
    axes[2].set_title('KL Divergence')
    axes[2].grid(True, alpha=0.3)
    axes[2].legend()
    
    plt.suptitle(title, fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig('training_curves.png', dpi=150, bbox_inches='tight')
    plt.show()
    
    print("Training curves saved to 'training_curves.png'")

# Plot the training curves
plot_training_curves(history, "Standard VAE (β=1) Training Progress")
No description has been provided for this image
Training curves saved to 'training_curves.png'

9. Generate Images from Random Noise (16-Image Grid)¶

Now the fun part! Let's see what our VAE "dreams up" when we give it random noise as input. This is the true test of a generative model.

In [10]:
def generate_image_grid(model, num_images=16, nrow=4, title="Generated Images"):
    """Generate a grid of images from random latent vectors"""
    model.eval()
    
    with torch.no_grad():
        # Sample random latent vectors from standard normal
        z = torch.randn(num_images, model.latent_dim).to(device)
        
        # Generate images
        generated = model.decoder(z)
        
        # Denormalize from [-1, 1] to [0, 1]
        generated = generated * 0.5 + 0.5
        generated = generated.cpu()
    
    # Create grid
    fig, axes = plt.subplots(nrow, nrow, figsize=(10, 10))
    
    for i, ax in enumerate(axes.flat):
        if i < num_images:
            img = generated[i].numpy().transpose(1, 2, 0)
            img = np.clip(img, 0, 1)
            ax.imshow(img)
        ax.axis('off')
    
    plt.suptitle(title, fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.savefig('generated_grid.png', dpi=150, bbox_inches='tight')
    plt.show()
    
    print("Generated image grid saved to 'generated_grid.png'")

# Generate 16 random images in a 4x4 grid
print("Generating 16 images from random noise...")
generate_image_grid(model, num_images=16, nrow=4, title="VAE Generated Images (β=1)")
Generating 16 images from random noise...
No description has been provided for this image
Generated image grid saved to 'generated_grid.png'

10. Reconstruction Quality Check¶

Let's also see how well the VAE reconstructs actual images. This helps verify the encoder-decoder pipeline is working correctly.

In [11]:
def show_reconstructions(model, data_loader, num_images=8):
    """Show original images vs their reconstructions"""
    model.eval()
    
    # Get a batch of test images
    images, labels = next(iter(data_loader))
    images = images[:num_images].to(device)
    
    with torch.no_grad():
        recon, _, _ = model(images)
    
    # Denormalize
    images = images * 0.5 + 0.5
    recon = recon * 0.5 + 0.5
    
    # Plot
    fig, axes = plt.subplots(2, num_images, figsize=(14, 4))
    
    for i in range(num_images):
        # Original
        orig = images[i].cpu().numpy().transpose(1, 2, 0)
        axes[0, i].imshow(np.clip(orig, 0, 1))
        axes[0, i].axis('off')
        if i == 0:
            axes[0, i].set_title('Original', fontsize=12)
        
        # Reconstruction
        rec = recon[i].cpu().numpy().transpose(1, 2, 0)
        axes[1, i].imshow(np.clip(rec, 0, 1))
        axes[1, i].axis('off')
        if i == 0:
            axes[1, i].set_title('Reconstructed', fontsize=12)
    
    plt.suptitle('Original vs Reconstructed Images', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig('reconstructions.png', dpi=150, bbox_inches='tight')
    plt.show()
    
    print("Reconstruction comparison saved to 'reconstructions.png'")

# Show reconstructions
print("Comparing original images with their reconstructions...")
show_reconstructions(model, test_loader, num_images=8)
Comparing original images with their reconstructions...
No description has been provided for this image
Reconstruction comparison saved to 'reconstructions.png'

11. Latent Space Interpolation (Task 3)¶

This is where VAEs really shine! We can "morph" between two images by interpolating in the latent space.

The Procedure:

  1. Take two random latent vectors z₁ and z₂
  2. Create 10 intermediate points: z = (1-α)·z₁ + α·z₂ for α ∈ [0, 1]
  3. Decode each point to see the smooth transition

If the latent space is well-organized, we should see meaningful, gradual changes!

In [12]:
def latent_space_interpolation(model, num_steps=10, num_rows=3):
    """
    Perform linear interpolation between random latent vectors.
    
    Creates multiple rows of interpolations to show different transitions.
    Each step is clearly labeled for evaluation.
    """
    model.eval()
    
    fig, axes = plt.subplots(num_rows, num_steps, figsize=(num_steps*1.5, num_rows*1.5))
    
    for row in range(num_rows):
        # Sample two random latent vectors
        z1 = torch.randn(1, model.latent_dim).to(device)
        z2 = torch.randn(1, model.latent_dim).to(device)
        
        # Create interpolation steps (linear interpolation)
        alphas = torch.linspace(0, 1, num_steps)
        
        for i, alpha in enumerate(alphas):
            # Linear interpolation: z = (1-α)*z1 + α*z2
            z_interp = (1 - alpha) * z1 + alpha * z2
            
            # Generate image
            with torch.no_grad():
                img = model.decoder(z_interp)
            
            # Denormalize and display
            img = (img * 0.5 + 0.5).cpu().squeeze(0).numpy().transpose(1, 2, 0)
            img = np.clip(img, 0, 1)
            
            axes[row, i].imshow(img)
            axes[row, i].axis('off')
            
            # Add clear step labels on first row
            if row == 0:
                axes[row, i].set_title(f'Step {i+1}/{num_steps}', fontsize=9, fontweight='bold')
            
            # Add alpha values on last row
            if row == num_rows - 1:
                axes[row, i].set_xlabel(f'α={alpha:.2f}', fontsize=8)
    
    plt.suptitle(f'Latent Space Interpolation ({num_steps} Steps)', 
                 fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig('interpolation.png', dpi=150, bbox_inches='tight')
    plt.show()
    
    print(f"Interpolation visualization saved showing {num_steps} clear steps")

# Perform interpolation
print("Exploring the latent space through interpolation...")
latent_space_interpolation(model, num_steps=10, num_rows=3)
Exploring the latent space through interpolation...
No description has been provided for this image
Interpolation visualization saved showing 10 clear steps

12. β-VAE Experiment (Task 4)¶

Now for the advanced part! β-VAE modifies the loss by adding a weight to the KL divergence term:

$$\mathcal{L}_{\beta-VAE} = \mathcal{L}_{recon} + \beta \cdot \mathcal{L}_{KL}$$

What happens when we increase β?

  • β = 1 (standard VAE): Balance between reconstruction and latent regularization
  • β > 1 (β-VAE): Stronger pressure on latent space → more disentangled features, but potentially blurrier reconstructions

Let's train with β = 5 and compare!

In [13]:
model_beta1 = model
history_beta1 = history

print("=" * 60)
print("TRAINING β-VAE (β = 5)")
print("=" * 60)

model_beta5 = VAE(LATENT_DIM).to(device)
history_beta5 = train_vae(model_beta5, train_loader, EPOCHS, beta=5.0, lr=LEARNING_RATE)

print("\nβ-VAE training complete!")
============================================================
TRAINING β-VAE (β = 5)
============================================================
Epoch 1/50: 100%|██████████| 391/391 [00:12<00:00, 30.76it/s, loss=360.6312, recon=294.3878, kl=13.2487]
  → Avg Loss: 498.0148 | Recon: 438.0461 | KL: 11.9937
Epoch 2/50: 100%|██████████| 391/391 [00:13<00:00, 30.07it/s, loss=347.5817, recon=270.5161, kl=15.4131]
  → Avg Loss: 367.7309 | Recon: 292.5369 | KL: 15.0388
Epoch 3/50: 100%|██████████| 391/391 [00:12<00:00, 30.14it/s, loss=319.2070, recon=244.1280, kl=15.0158]
  → Avg Loss: 349.2755 | Recon: 269.6205 | KL: 15.9310
Epoch 4/50: 100%|██████████| 391/391 [00:12<00:00, 30.39it/s, loss=346.5734, recon=270.2445, kl=15.2658]
  → Avg Loss: 342.5169 | Recon: 260.8127 | KL: 16.3408
Epoch 5/50: 100%|██████████| 391/391 [00:12<00:00, 30.31it/s, loss=338.4818, recon=253.3309, kl=17.0302]
  → Avg Loss: 339.2335 | Recon: 256.1645 | KL: 16.6138
Epoch 6/50: 100%|██████████| 391/391 [00:12<00:00, 30.73it/s, loss=341.4335, recon=257.1042, kl=16.8659]
  → Avg Loss: 336.2771 | Recon: 252.3392 | KL: 16.7876
Epoch 7/50: 100%|██████████| 391/391 [00:12<00:00, 30.24it/s, loss=341.9691, recon=253.5479, kl=17.6842]
  → Avg Loss: 334.0958 | Recon: 249.2468 | KL: 16.9698
Epoch 8/50: 100%|██████████| 391/391 [00:12<00:00, 30.45it/s, loss=308.1805, recon=227.6594, kl=16.1042]
  → Avg Loss: 333.2896 | Recon: 247.8258 | KL: 17.0928
Epoch 9/50: 100%|██████████| 391/391 [00:13<00:00, 29.93it/s, loss=336.5240, recon=250.3893, kl=17.2269]
  → Avg Loss: 331.5870 | Recon: 245.4277 | KL: 17.2319
Epoch 10/50: 100%|██████████| 391/391 [00:13<00:00, 30.04it/s, loss=312.4016, recon=227.1611, kl=17.0481]
  → Avg Loss: 330.3771 | Recon: 243.6973 | KL: 17.3360
Epoch 11/50: 100%|██████████| 391/391 [00:12<00:00, 30.40it/s, loss=336.1519, recon=251.2824, kl=16.9739]
  → Avg Loss: 329.3207 | Recon: 242.1999 | KL: 17.4242
Epoch 12/50: 100%|██████████| 391/391 [00:12<00:00, 30.08it/s, loss=320.1717, recon=229.7546, kl=18.0834]
  → Avg Loss: 328.3229 | Recon: 240.8731 | KL: 17.4899
Epoch 13/50: 100%|██████████| 391/391 [00:13<00:00, 30.03it/s, loss=324.5407, recon=235.9893, kl=17.7103]
  → Avg Loss: 327.2161 | Recon: 239.1024 | KL: 17.6227
Epoch 14/50: 100%|██████████| 391/391 [00:13<00:00, 30.06it/s, loss=326.3479, recon=230.5851, kl=19.1525]
  → Avg Loss: 326.2984 | Recon: 237.9564 | KL: 17.6684
Epoch 15/50: 100%|██████████| 391/391 [00:12<00:00, 30.19it/s, loss=335.2366, recon=248.6060, kl=17.3261]
  → Avg Loss: 325.7186 | Recon: 236.9746 | KL: 17.7488
Epoch 16/50: 100%|██████████| 391/391 [00:12<00:00, 30.34it/s, loss=323.2148, recon=235.5401, kl=17.5349]
  → Avg Loss: 325.1229 | Recon: 235.8434 | KL: 17.8559
Epoch 17/50: 100%|██████████| 391/391 [00:13<00:00, 29.72it/s, loss=312.9568, recon=225.3151, kl=17.5283]
  → Avg Loss: 324.5417 | Recon: 235.1915 | KL: 17.8700
Epoch 18/50: 100%|██████████| 391/391 [00:13<00:00, 29.77it/s, loss=319.3827, recon=232.0301, kl=17.4705]
  → Avg Loss: 323.9863 | Recon: 234.4843 | KL: 17.9004
Epoch 19/50: 100%|██████████| 391/391 [00:13<00:00, 29.39it/s, loss=301.7355, recon=214.6904, kl=17.4090]
  → Avg Loss: 323.6265 | Recon: 233.8999 | KL: 17.9453
Epoch 20/50: 100%|██████████| 391/391 [00:13<00:00, 29.47it/s, loss=329.6045, recon=237.8358, kl=18.3537]
  → Avg Loss: 322.9839 | Recon: 233.0444 | KL: 17.9879
Epoch 21/50: 100%|██████████| 391/391 [00:13<00:00, 29.98it/s, loss=333.3153, recon=240.2998, kl=18.6031]
  → Avg Loss: 322.8141 | Recon: 232.8469 | KL: 17.9934
Epoch 22/50: 100%|██████████| 391/391 [00:13<00:00, 29.99it/s, loss=345.1356, recon=248.7317, kl=19.2808]
  → Avg Loss: 322.3970 | Recon: 232.0943 | KL: 18.0605
Epoch 23/50: 100%|██████████| 391/391 [00:12<00:00, 30.31it/s, loss=320.1385, recon=226.9476, kl=18.6382]
  → Avg Loss: 321.7056 | Recon: 231.3279 | KL: 18.0755
Epoch 24/50: 100%|██████████| 391/391 [00:12<00:00, 30.20it/s, loss=333.8743, recon=244.2642, kl=17.9220]
  → Avg Loss: 321.2803 | Recon: 230.8729 | KL: 18.0815
Epoch 25/50: 100%|██████████| 391/391 [00:12<00:00, 30.38it/s, loss=339.6066, recon=246.6411, kl=18.5931]
  → Avg Loss: 321.4419 | Recon: 230.7144 | KL: 18.1455
Epoch 26/50: 100%|██████████| 391/391 [00:13<00:00, 29.98it/s, loss=329.6647, recon=236.3800, kl=18.6569]
  → Avg Loss: 320.5731 | Recon: 229.7550 | KL: 18.1636
Epoch 27/50: 100%|██████████| 391/391 [00:13<00:00, 29.89it/s, loss=299.0518, recon=206.9067, kl=18.4290]
  → Avg Loss: 320.5853 | Recon: 229.7333 | KL: 18.1704
Epoch 28/50: 100%|██████████| 391/391 [00:12<00:00, 30.09it/s, loss=328.3957, recon=238.6541, kl=17.9483]
  → Avg Loss: 320.1299 | Recon: 229.1959 | KL: 18.1868
Epoch 29/50: 100%|██████████| 391/391 [00:12<00:00, 30.20it/s, loss=347.8535, recon=256.0664, kl=18.3574]
  → Avg Loss: 319.8510 | Recon: 228.6904 | KL: 18.2321
Epoch 30/50: 100%|██████████| 391/391 [00:12<00:00, 30.09it/s, loss=297.7848, recon=207.8775, kl=17.9815]
  → Avg Loss: 319.5578 | Recon: 228.2528 | KL: 18.2610
Epoch 31/50: 100%|██████████| 391/391 [00:12<00:00, 30.08it/s, loss=316.4286, recon=223.8820, kl=18.5093]
  → Avg Loss: 318.9221 | Recon: 227.5597 | KL: 18.2725
Epoch 32/50: 100%|██████████| 391/391 [00:12<00:00, 30.15it/s, loss=328.6658, recon=240.5954, kl=17.6141]
  → Avg Loss: 319.1293 | Recon: 227.6107 | KL: 18.3037
Epoch 33/50: 100%|██████████| 391/391 [00:12<00:00, 30.29it/s, loss=298.4813, recon=207.7271, kl=18.1508]
  → Avg Loss: 319.0369 | Recon: 227.2459 | KL: 18.3582
Epoch 34/50: 100%|██████████| 391/391 [00:12<00:00, 30.36it/s, loss=341.5300, recon=248.1841, kl=18.6692]
  → Avg Loss: 318.4974 | Recon: 226.8547 | KL: 18.3285
Epoch 35/50: 100%|██████████| 391/391 [00:13<00:00, 29.88it/s, loss=341.9253, recon=250.6912, kl=18.2468]
  → Avg Loss: 317.9969 | Recon: 226.4132 | KL: 18.3167
Epoch 36/50: 100%|██████████| 391/391 [00:12<00:00, 30.41it/s, loss=317.9156, recon=224.1446, kl=18.7542]
  → Avg Loss: 317.6670 | Recon: 226.0566 | KL: 18.3221
Epoch 37/50: 100%|██████████| 391/391 [00:13<00:00, 30.03it/s, loss=324.7522, recon=231.1570, kl=18.7190]
  → Avg Loss: 318.2747 | Recon: 226.3808 | KL: 18.3788
Epoch 38/50: 100%|██████████| 391/391 [00:12<00:00, 30.24it/s, loss=319.9552, recon=230.8044, kl=17.8302]
  → Avg Loss: 317.6491 | Recon: 225.9635 | KL: 18.3371
Epoch 39/50: 100%|██████████| 391/391 [00:12<00:00, 30.66it/s, loss=320.2074, recon=226.5878, kl=18.7239]
  → Avg Loss: 317.2274 | Recon: 225.3707 | KL: 18.3713
Epoch 40/50: 100%|██████████| 391/391 [00:13<00:00, 29.85it/s, loss=322.7083, recon=228.8596, kl=18.7698]
  → Avg Loss: 316.9150 | Recon: 224.7879 | KL: 18.4254
Epoch 41/50: 100%|██████████| 391/391 [00:12<00:00, 30.34it/s, loss=315.1140, recon=219.7717, kl=19.0685]
  → Avg Loss: 316.5589 | Recon: 224.5129 | KL: 18.4092
Epoch 42/50: 100%|██████████| 391/391 [00:12<00:00, 30.08it/s, loss=318.6714, recon=224.2823, kl=18.8778]
  → Avg Loss: 316.4340 | Recon: 224.3202 | KL: 18.4228
Epoch 43/50: 100%|██████████| 391/391 [00:12<00:00, 30.47it/s, loss=306.1914, recon=214.0939, kl=18.4195]
  → Avg Loss: 316.5797 | Recon: 224.3206 | KL: 18.4518
Epoch 44/50: 100%|██████████| 391/391 [00:13<00:00, 29.75it/s, loss=317.1483, recon=226.1165, kl=18.2064]
  → Avg Loss: 316.6710 | Recon: 224.3024 | KL: 18.4737
Epoch 45/50: 100%|██████████| 391/391 [00:13<00:00, 29.32it/s, loss=306.4611, recon=216.0226, kl=18.0877]
  → Avg Loss: 316.3152 | Recon: 223.8738 | KL: 18.4883
Epoch 46/50: 100%|██████████| 391/391 [00:13<00:00, 28.73it/s, loss=314.2213, recon=224.1857, kl=18.0071]
  → Avg Loss: 315.9602 | Recon: 223.4631 | KL: 18.4994
Epoch 47/50: 100%|██████████| 391/391 [00:13<00:00, 29.06it/s, loss=330.5685, recon=237.2563, kl=18.6624]
  → Avg Loss: 315.7213 | Recon: 223.2480 | KL: 18.4947
Epoch 48/50: 100%|██████████| 391/391 [00:13<00:00, 30.01it/s, loss=310.8220, recon=214.5401, kl=19.2564]
  → Avg Loss: 315.7171 | Recon: 223.1837 | KL: 18.5067
Epoch 49/50: 100%|██████████| 391/391 [00:13<00:00, 29.29it/s, loss=303.8752, recon=212.2518, kl=18.3247]
  → Avg Loss: 315.4992 | Recon: 222.9770 | KL: 18.5044
Epoch 50/50: 100%|██████████| 391/391 [00:13<00:00, 29.42it/s, loss=322.4139, recon=229.5761, kl=18.5676]
  → Avg Loss: 315.4138 | Recon: 222.5905 | KL: 18.5647

β-VAE training complete!

In [14]:
def plot_beta_comparison(history_beta1, history_beta5):
    """Create side-by-side comparison plots for β=1 vs β=5 training"""
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    epochs = range(1, len(history_beta1['total_loss']) + 1)
    
    # Total Loss
    axes[0].plot(epochs, history_beta1['total_loss'], 'b-', label='β=1', linewidth=2)
    axes[0].plot(epochs, history_beta5['total_loss'], 'r-', label='β=5', linewidth=2)
    axes[0].set_xlabel('Epoch', fontsize=11)
    axes[0].set_ylabel('Loss', fontsize=11)
    axes[0].set_title('Total Loss Comparison', fontsize=12, fontweight='bold')
    axes[0].legend(fontsize=10)
    axes[0].grid(True, alpha=0.3)
    
    # Reconstruction Loss
    axes[1].plot(epochs, history_beta1['recon_loss'], 'b-', label='β=1', linewidth=2)
    axes[1].plot(epochs, history_beta5['recon_loss'], 'r-', label='β=5', linewidth=2)
    axes[1].set_xlabel('Epoch', fontsize=11)
    axes[1].set_ylabel('Loss', fontsize=11)
    axes[1].set_title('Reconstruction Loss Comparison', fontsize=12, fontweight='bold')
    axes[1].legend(fontsize=10)
    axes[1].grid(True, alpha=0.3)
    
    # KL Divergence
    axes[2].plot(epochs, history_beta1['kl_loss'], 'b-', label='β=1', linewidth=2)
    axes[2].plot(epochs, history_beta5['kl_loss'], 'r-', label='β=5', linewidth=2)
    axes[2].set_xlabel('Epoch', fontsize=11)
    axes[2].set_ylabel('Loss', fontsize=11)
    axes[2].set_title('KL Divergence Comparison', fontsize=12, fontweight='bold')
    axes[2].legend(fontsize=10)
    axes[2].grid(True, alpha=0.3)
    
    plt.suptitle('β-VAE Training Comparison: Direct Overlay', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig('beta_comparison_overlay.png', dpi=150, bbox_inches='tight')
    plt.show()
    
    print("β comparison overlay saved to 'beta_comparison_overlay.png'")

print("β comparison plotting function defined")
β comparison plotting function defined
In [15]:
# Plot individual β=5 training curves
print("Plotting β=5 training curves...")
plot_training_curves(history_beta5, "β-VAE (β=5) Training Progress")
Plotting β=5 training curves...
No description has been provided for this image
Training curves saved to 'training_curves.png'
In [16]:
# Plot overlay comparison of both models
print("Creating overlay comparison plot...")
plot_beta_comparison(history_beta1, history_beta5)
Creating overlay comparison plot...
No description has been provided for this image
β comparison overlay saved to 'beta_comparison_overlay.png'

13. Comparing β=1 vs β=5¶

Let's compare the two models side-by-side to see the impact of the β parameter.

In [17]:
# Compare training curves
fig, axes = plt.subplots(2, 3, figsize=(15, 8))

epochs = range(1, len(history_beta1['total_loss']) + 1)

# Row 1: β = 1
axes[0, 0].plot(epochs, history_beta1['total_loss'], 'b-', linewidth=2)
axes[0, 0].set_title('Total Loss (β=1)')
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].grid(True, alpha=0.3)

axes[0, 1].plot(epochs, history_beta1['recon_loss'], 'g-', linewidth=2)
axes[0, 1].set_title('Reconstruction Loss (β=1)')
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].grid(True, alpha=0.3)

axes[0, 2].plot(epochs, history_beta1['kl_loss'], 'r-', linewidth=2)
axes[0, 2].set_title('KL Loss (β=1)')
axes[0, 2].set_xlabel('Epoch')
axes[0, 2].grid(True, alpha=0.3)

# Row 2: β = 5
axes[1, 0].plot(epochs, history_beta5['total_loss'], 'b-', linewidth=2)
axes[1, 0].set_title('Total Loss (β=5)')
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].grid(True, alpha=0.3)

axes[1, 1].plot(epochs, history_beta5['recon_loss'], 'g-', linewidth=2)
axes[1, 1].set_title('Reconstruction Loss (β=5)')
axes[1, 1].set_xlabel('Epoch')
axes[1, 1].grid(True, alpha=0.3)

axes[1, 2].plot(epochs, history_beta5['kl_loss'], 'r-', linewidth=2)
axes[1, 2].set_title('KL Loss (β=5)')
axes[1, 2].set_xlabel('Epoch')
axes[1, 2].grid(True, alpha=0.3)

plt.suptitle('Training Comparison: Standard VAE vs β-VAE', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig('beta_comparison_curves.png', dpi=150, bbox_inches='tight')
plt.show()

print("Comparison curves saved to 'beta_comparison_curves.png'")
No description has been provided for this image
Comparison curves saved to 'beta_comparison_curves.png'
In [18]:
# Compare generated images side by side
def compare_generations(model1, model2, title1="β=1", title2="β=5", num_images=8):
    """Generate images from both models for comparison"""
    model1.eval()
    model2.eval()
    
    # Use the SAME random seed for fair comparison
    torch.manual_seed(42)
    z = torch.randn(num_images, LATENT_DIM).to(device)
    
    with torch.no_grad():
        gen1 = model1.decoder(z)
        gen2 = model2.decoder(z)
    
    # Denormalize
    gen1 = (gen1 * 0.5 + 0.5).cpu()
    gen2 = (gen2 * 0.5 + 0.5).cpu()
    
    # Plot
    fig, axes = plt.subplots(2, num_images, figsize=(14, 4))
    
    for i in range(num_images):
        img1 = gen1[i].numpy().transpose(1, 2, 0)
        img2 = gen2[i].numpy().transpose(1, 2, 0)
        
        axes[0, i].imshow(np.clip(img1, 0, 1))
        axes[0, i].axis('off')
        if i == 0:
            axes[0, i].set_ylabel(title1, fontsize=12, rotation=0, labelpad=30)
        
        axes[1, i].imshow(np.clip(img2, 0, 1))
        axes[1, i].axis('off')
        if i == 0:
            axes[1, i].set_ylabel(title2, fontsize=12, rotation=0, labelpad=30)
    
    plt.suptitle('Generated Image Comparison (Same Latent Vectors)', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig('beta_comparison_images.png', dpi=150, bbox_inches='tight')
    plt.show()
    
    print("Image comparison saved to 'beta_comparison_images.png'")

# Compare!
print("Comparing generated images from both models...")
compare_generations(model_beta1, model_beta5)
Comparing generated images from both models...
No description has been provided for this image
Image comparison saved to 'beta_comparison_images.png'
In [19]:
# Compare reconstruction quality
def compare_reconstructions(model1, model2, data_loader, title1="β=1", title2="β=5", num_images=6):
    """Compare reconstruction quality of both models"""
    model1.eval()
    model2.eval()
    
    images, _ = next(iter(data_loader))
    images = images[:num_images].to(device)
    
    with torch.no_grad():
        recon1, _, _ = model1(images)
        recon2, _, _ = model2(images)
    
    # Denormalize
    images = (images * 0.5 + 0.5).cpu()
    recon1 = (recon1 * 0.5 + 0.5).cpu()
    recon2 = (recon2 * 0.5 + 0.5).cpu()
    
    # Plot
    fig, axes = plt.subplots(3, num_images, figsize=(12, 6))
    
    for i in range(num_images):
        orig = images[i].numpy().transpose(1, 2, 0)
        r1 = recon1[i].numpy().transpose(1, 2, 0)
        r2 = recon2[i].numpy().transpose(1, 2, 0)
        
        axes[0, i].imshow(np.clip(orig, 0, 1))
        axes[0, i].axis('off')
        if i == 0:
            axes[0, i].set_ylabel('Original', fontsize=10, rotation=0, labelpad=40)
        
        axes[1, i].imshow(np.clip(r1, 0, 1))
        axes[1, i].axis('off')
        if i == 0:
            axes[1, i].set_ylabel(f'Recon\n({title1})', fontsize=10, rotation=0, labelpad=40)
        
        axes[2, i].imshow(np.clip(r2, 0, 1))
        axes[2, i].axis('off')
        if i == 0:
            axes[2, i].set_ylabel(f'Recon\n({title2})', fontsize=10, rotation=0, labelpad=40)
    
    plt.suptitle('Reconstruction Quality Comparison', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig('beta_comparison_recon.png', dpi=150, bbox_inches='tight')
    plt.show()
    
    print("Reconstruction comparison saved to 'beta_comparison_recon.png'")

compare_reconstructions(model_beta1, model_beta5, test_loader)
No description has been provided for this image
Reconstruction comparison saved to 'beta_comparison_recon.png'
In [20]:
print("Generating 16-image grid for β-VAE (β=5)...")
generate_image_grid(model_beta5, num_images=16, nrow=4, title="β-VAE Generated Images (β=5)")
Generating 16-image grid for β-VAE (β=5)...
No description has been provided for this image
Generated image grid saved to 'generated_grid.png'

Observations on β-VAE (β=1 vs β=5)¶

Impact on Image Quality:

When I compared the standard VAE (β=1) with the β-VAE (β=5), I noticed a clear trade-off between how well the images are reconstructed and how organized the latent space is. The standard VAE produces sharper, more detailed images that better preserve fine textures and edges. In contrast, the β-VAE creates noticeably blurrier images. This happens because the stronger KL penalty forces the model to compress information more aggressively into the latent space.

Impact on Latent Space Organization:

A higher β value leads to better separation in the latent space. This means different latent dimensions capture more independent features, like color separate from shape or orientation separate from object type. During interpolation experiments, the β-VAE produced smoother and more meaningful transitions between images compared to the standard VAE, which sometimes showed inconsistent intermediate states.

The Trade-off Explained:

The β parameter manages the balance between two competing goals:

  1. Reconstruction accuracy (making the output look like the input)
  2. Latent regularization (structuring the latent space like N(0,I))

With β=1, both goals are weighted equally. With β=5, we penalize deviations from the prior significantly, causing the encoder to "forget" some details to keep the latent space well-organized. This explains why β-VAE images appear blurrier, while the latent representations are easier to interpret.

Practical Implications:

For applications that require high visual quality, like image restoration or super-resolution, a lower β is better. For tasks that need interpretable or controllable generation, such as style transfer or attribute manipulation, higher β values offer better feature separation, even if there is some loss in quality.

Conclusion:

The β-VAE experiment highlights the fundamental conflict in generative modeling between expressiveness and regularity. This is an important concept for understanding variational inference.

15. Final Summary & Saved Outputs¶

Submission¶

  1. Full source code - Encoder, Decoder, VAE, and Loss function
  2. Training curves - training_curves.png
  3. 16-image grid - generated_grid.png
  4. 10-step interpolation - interpolation.png
  5. β-VAE analysis - Written summary above
In [21]:
print("=" * 60)
print("FINAL TRAINING STATISTICS")
print("=" * 60)

print("\nStandard VAE (β=1):")
print(f"   Final Total Loss: {history_beta1['total_loss'][-1]:.4f}")
print(f"   Final Recon Loss: {history_beta1['recon_loss'][-1]:.4f}")
print(f"   Final KL Loss: {history_beta1['kl_loss'][-1]:.4f}")

print("\nβ-VAE (β=5):")
print(f"   Final Total Loss: {history_beta5['total_loss'][-1]:.4f}")
print(f"   Final Recon Loss: {history_beta5['recon_loss'][-1]:.4f}")
print(f"   Final KL Loss: {history_beta5['kl_loss'][-1]:.4f}")

print("\n" + "=" * 60)
print("Assignment Complete!")
print("=" * 60)
============================================================
FINAL TRAINING STATISTICS
============================================================

Standard VAE (β=1):
   Final Total Loss: 185.4197
   Final Recon Loss: 125.6657
   Final KL Loss: 59.7539

β-VAE (β=5):
   Final Total Loss: 315.4138
   Final Recon Loss: 222.5905
   Final KL Loss: 18.5647

============================================================
Assignment Complete!
============================================================
In [22]:
# Save individual model weights locally
torch.save(model_beta1.state_dict(), 'vae_beta1_weights.pth')
torch.save(model_beta5.state_dict(), 'vae_beta5_weights.pth')

print("Saved weights to 'vae_beta1_weights.pth' and 'vae_beta5_weights.pth'")
Saved weights to 'vae_beta1_weights.pth' and 'vae_beta5_weights.pth'
In [23]:
torch.save({
    'model_beta1_state_dict': model_beta1.state_dict(),
    'model_beta5_state_dict': model_beta5.state_dict(),
    'history_beta1': history_beta1,
    'history_beta5': history_beta5,
    'latent_dim': LATENT_DIM,
    'epochs': EPOCHS,
}, 'vae_models.pth')

print("Models saved to 'vae_models.pth'")
Models saved to 'vae_models.pth'

16. Display All Generated Images¶

Let's view all the saved outputs from our training!

In [24]:
from PIL import Image
import os

def display_saved_image(filepath, title):
    """Display a saved image if it exists"""
    if os.path.exists(filepath):
        img = Image.open(filepath)
        plt.figure(figsize=(15, 10))
        plt.imshow(img)
        plt.axis('off')
        plt.title(title, fontsize=16, fontweight='bold', pad=20)
        plt.tight_layout()
        plt.show()
        print(f"✓ Displayed: {filepath}\n")
    else:
        print(f"✗ File not found: {filepath}\n")

print("=" * 60)
print("DISPLAYING ALL SAVED OUTPUTS")
print("=" * 60)
print()
============================================================
DISPLAYING ALL SAVED OUTPUTS
============================================================

1. Training Curves (Standard VAE β=1)¶

In [25]:
display_saved_image('training_curves.png', 'Training Progress: Standard VAE (β=1)')
No description has been provided for this image
✓ Displayed: training_curves.png

2. Generated Images (16-Image Grid)¶

In [26]:
display_saved_image('generated_grid.png', 'VAE Generated Images (Random Noise → Images)')
No description has been provided for this image
✓ Displayed: generated_grid.png

3. Image Reconstructions¶

In [27]:
display_saved_image('reconstructions.png', 'Original vs Reconstructed Images')
No description has been provided for this image
✓ Displayed: reconstructions.png

4. Latent Space Interpolation (10 Steps)¶

In [28]:
display_saved_image('interpolation.png', 'Latent Space Interpolation: Smooth Morphing Between Images')
No description has been provided for this image
✓ Displayed: interpolation.png

5. β-VAE Comparison: Training Curves¶

In [29]:
display_saved_image('beta_comparison_curves.png', 'β-VAE Training Comparison: β=1 vs β=5')
No description has been provided for this image
✓ Displayed: beta_comparison_curves.png

6. β-VAE Comparison: Generated Images¶

In [30]:
display_saved_image('beta_comparison_images.png', 'Generated Images: Same Latent Vectors, Different β')
No description has been provided for this image
✓ Displayed: beta_comparison_images.png

7. β-VAE Comparison: Reconstruction Quality¶

In [31]:
display_saved_image('beta_comparison_recon.png', 'Reconstruction Quality: β=1 vs β=5')
No description has been provided for this image
✓ Displayed: beta_comparison_recon.png